# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access

import ast
import concurrent.futures
import logging
import time
from concurrent.futures import Future
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

from azure.core.credentials import TokenCredential
from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import DataVersion, UriFileJobOutput
from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource, is_registry_id_for_resource
from azure.ai.ml._utils._logger_utils import initialize_logger_info
from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType, DefaultOpenEncoding, LROConfigurations
from azure.ai.ml.entities import BatchDeployment
from azure.ai.ml.entities._assets._artifacts.code import Code
from azure.ai.ml.entities._deployment.deployment import Deployment
from azure.ai.ml.entities._deployment.model_batch_deployment import ModelBatchDeployment
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException
from azure.ai.ml.operations._operation_orchestrator import OperationOrchestrator
from azure.core.exceptions import (
    ClientAuthenticationError,
    HttpResponseError,
    ResourceExistsError,
    ResourceNotFoundError,
    ServiceRequestTimeoutError,
    map_error,
)
from azure.core.polling import LROPoller
from azure.core.rest import HttpResponse
from azure.mgmt.core.exceptions import ARMErrorFormat

# Module-level logger shared by the helpers below.
# NOTE(review): `terminator=""` presumably stops the handler from appending a newline to each record
# (confirm in azure.ai.ml._utils._logger_utils). Messages in this module embed their own "\n", which lets
# polling_wait emit its "." progress markers on a single line.
module_logger = logging.getLogger(__name__)
initialize_logger_info(module_logger, terminator="")


def check_default_deployment_template(deployment: Deployment, credential: Optional[TokenCredential]) -> None:
    """Check if a registry model has a default deployment template and log if found.

    This is a best-effort, purely informational check. It only runs when a credential is supplied and
    ``deployment.model`` is a registry ID. Any failure while resolving the model is logged at debug level
    and otherwise ignored, so this check can never block the deployment.

    :param Deployment deployment: Endpoint deployment object.
    :param credential: Credential for registry operations.
    :type credential: Optional[TokenCredential]
    """
    if not (credential and deployment.model and is_registry_id_for_resource(deployment.model)):
        return

    try:
        # Imported lazily, as before. NOTE(review): presumably to avoid import cycles or import-time
        # cost for callers that never deploy a registry model.
        import re
        from azure.ai.ml.constants._common import REGISTRY_VERSION_PATTERN
        from azure.ai.ml.entities._assets._artifacts.model import Model
        from azure.ai.ml._utils._registry_utils import get_registry_client

        match = re.match(REGISTRY_VERSION_PATTERN, deployment.model, re.IGNORECASE)
        if not match:
            # Not a versioned registry model reference; nothing to look up.
            return

        registry_name = match.group(1)
        model_name = match.group(3)
        model_version = match.group(4)

        service_client, resource_group_name, _, _ = get_registry_client(
            credential=credential,
            registry_name=registry_name,
        )

        model_version_data = service_client.model_versions.get(
            name=model_name,
            version=model_version,
            registry_name=registry_name,
            resource_group_name=resource_group_name,
        )

        model = Model._from_rest_object(model_version_data)

        template = getattr(model, "default_deployment_template", None)
        if template:
            module_logger.info(
                "\nModel '%s' (version %s) from registry '%s' has a "
                "default deployment template configured.\n"
                "The deployment will use the default deployment template settings, "
                "and some deployment parameters may be ignored.\n"
                "Default deployment template: %s\n",
                model_name,
                model_version,
                registry_name,
                template.asset_id,
            )
    except Exception:  # pylint: disable=broad-except
        # Never fail the deployment because of this informational lookup. Still leave a trace, so that
        # auth, network, or unexpected-format failures can be diagnosed instead of vanishing silently.
        module_logger.debug(
            "Could not check default deployment template for model '%s'.\n",
            deployment.model,
            exc_info=True,
        )


def get_duration(start_time: float) -> None:
    """Log how long a long-running operation took, as whole minutes and seconds.

    :param start_time: Timestamp (as produced by ``time.time()``) captured when the operation began.
    :type start_time: float
    """
    elapsed_seconds = int(round(time.time() - start_time))
    minutes, seconds = divmod(elapsed_seconds, 60)
    module_logger.warning("(%sm %ss)\n", minutes, seconds)


def polling_wait(
    poller: Union[LROPoller, Future],
    message: Optional[str] = None,
    start_time: Optional[float] = None,
    is_local=False,
    timeout=LROConfigurations.POLLING_TIMEOUT,
) -> Any:
    """Print out status while polling and time of operation once completed.

    :param poller: A poller which reports completion via ``done()``.
    :type poller: Union[LROPoller, concurrent.futures.Future]
    :param message: Message to print out before starting operation write-out. Nothing is printed when None.
    :type message: Optional[str]
    :param start_time: Start time of the operation (a ``time.time()`` value). When given, the elapsed
        time is logged once polling finishes.
    :type start_time: Optional[float]
    :param is_local: If poller is for a local endpoint, so the timeout is removed.
    :type is_local: bool
    :param timeout: New value to overwrite the default timeout.
    :type timeout: int
    """
    # Skip the banner when no message was supplied, instead of logging the literal "None".
    if message is not None:
        module_logger.warning("%s", message)

    if is_local:
        # No timeout for local endpoints: pulling an image or installing a conda env can take a long
        # time, and the "." progress output lets the user see that work is still happening.
        while not poller.done():
            module_logger.warning(".")
            time.sleep(LROConfigurations.SLEEP_TIME)
    else:
        # NOTE(review): a concurrent.futures.Future raises TimeoutError here when the timeout elapses;
        # an LROPoller presumably returns after waiting, which leaves the timeout branch below reachable.
        poller.result(timeout=timeout)

    if poller.done():
        module_logger.warning("Done ")
    else:
        module_logger.warning("Timeout waiting for long running operation")

    # Compare against None so that a falsy-but-valid start time is not silently ignored.
    if start_time is not None:
        get_duration(start_time)


def local_endpoint_polling_wrapper(func: Callable, message: str, **kwargs) -> Any:
    """Run ``func(**kwargs)`` on a background thread while printing polling progress.

    :param func: The local endpoint operation to run.
    :type func: Callable
    :param message: Message to print out before starting operation write-out.
    :type message: str
    :return: The value returned by ``func``.
    :rtype: Any
    :raises Exception: Whatever ``func`` raises is re-raised from ``result()``.
    """
    # Only one task is ever submitted, so a single worker suffices.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        start_time = time.time()
        event = pool.submit(func, **kwargs)
        polling_wait(poller=event, start_time=start_time, message=message, is_local=True)
        return event.result()
    finally:
        # Release the worker thread explicitly instead of leaving the pool open until garbage collection.
        # wait=False so an interrupted poll (e.g. KeyboardInterrupt) does not additionally block here.
        pool.shutdown(wait=False)


def validate_response(response: HttpResponse) -> None:
    """Validates the response of POST requests, throws on error.

    Status codes 200 and 201 are treated as success; every other status raises.

    :param HttpResponse response: the response of a POST requests
    :raises MlException: Raised when a non-204 error response body is not valid JSON
    :raises HttpResponseError: Raised when the response signals that an error occurred. Status codes
        401/404/408/409/424 are mapped to their specific azure-core error subclasses.
    """
    if response.status_code in (200, 201):
        return

    r_json: Any = {}
    # We need to check for an empty response body or catch the exception raised.
    # It is possible the server responded with a 204 No Content response, and json parsing fails.
    if response.status_code != 204:
        try:
            r_json = response.json()
        except ValueError as e:
            # The body is not JSON. Surface it verbatim, replacing undecodable bytes so that a
            # UnicodeDecodeError does not mask the real service error.
            msg = response.content.decode("utf-8", errors="replace")
            raise MlException(message=msg, no_personal_data_message=msg) from e

    # Tolerate bodies that are valid JSON but not shaped like {"error": {"message": ...}}
    # (e.g. a list, or "error" given as a string/null). Such bodies previously raised AttributeError.
    error_body = r_json.get("error") if isinstance(r_json, dict) else None
    failure_msg = error_body.get("message", response) if isinstance(error_body, dict) else response

    error_map = {
        401: ClientAuthenticationError,
        404: ResourceNotFoundError,
        408: ServiceRequestTimeoutError,
        409: ResourceExistsError,
        424: HttpResponseError,
    }
    map_error(status_code=response.status_code, response=response, error_map=error_map)
    raise HttpResponseError(response=response, message=failure_msg, error_format=ARMErrorFormat)


def upload_dependencies(
    deployment: Deployment, orchestrators: OperationOrchestrator, credential: Optional[TokenCredential] = None
) -> None:
    """Resolve a deployment's code, environment and model references to asset ARM IDs.

    For code/environment/model values that are neither ARM IDs nor registry IDs, each value is
    resolved through the orchestrator. Local code is passed to the orchestrator as a ``Code`` asset.
    For batch deployments the compute target is also resolved to an ARM ID.

    :param Deployment deployment: Endpoint deployment object; its fields are updated in place.
    :param OperationOrchestrator orchestrators: Operation Orchestrator.
    :param credential: Optional credential for registry operations.
    :type credential: Optional[TokenCredential]
    """

    module_logger.debug("Uploading the dependencies for deployment %s", deployment.name)

    code_config = deployment.code_configuration
    if code_config:
        code_ref = code_config.code
        # Only create a code asset when the code is not already an ARM ID or a registry ID.
        needs_resolution = not is_ARM_id_for_resource(code_ref, AzureMLResourceType.CODE) and (
            not is_registry_id_for_resource(code_ref)
        )
        if needs_resolution:
            if code_ref.startswith(ARM_ID_PREFIX):
                code_asset = code_ref[len(ARM_ID_PREFIX) :]
            else:
                code_asset = Code(base_path=deployment._base_path, path=code_ref)
            code_config.code = orchestrators.get_asset_arm_id(code_asset, azureml_type=AzureMLResourceType.CODE)

    if not is_registry_id_for_resource(deployment.environment):
        environment = deployment.environment
        if environment:
            deployment.environment = orchestrators.get_asset_arm_id(
                environment, azureml_type=AzureMLResourceType.ENVIRONMENT
            )
        else:
            deployment.environment = None

    check_default_deployment_template(deployment, credential)

    if not is_registry_id_for_resource(deployment.model):
        model = deployment.model
        if model:
            deployment.model = orchestrators.get_asset_arm_id(model, azureml_type=AzureMLResourceType.MODEL)
        else:
            deployment.model = None

    is_batch = isinstance(deployment, (BatchDeployment, ModelBatchDeployment))
    if is_batch and deployment.compute:
        deployment.compute = orchestrators.get_asset_arm_id(
            deployment.compute, azureml_type=AzureMLResourceType.COMPUTE
        )


def validate_scoring_script(deployment):
    """Check that a deployment's scoring script is syntactically valid Python.

    The script is located at ``deployment.base_path / deployment.code_configuration.code /
    deployment.scoring_script`` and parsed with :func:`ast.parse`. It is not executed.

    :param deployment: Deployment whose scoring script should be validated.
    :raises MlException: Raised when the scoring script cannot be opened or read (``OSError``).
    :raises ValidationException: Raised when the scoring script contains syntax errors.
    """
    score_script_path = Path(deployment.base_path).joinpath(
        deployment.code_configuration.code, deployment.scoring_script
    )

    # Keep the try body to the file I/O only; parse errors are handled separately below.
    try:
        with open(score_script_path, "r", encoding=DefaultOpenEncoding.READ) as script:
            contents = script.read()
    except OSError as err:
        raise MlException(
            message=f"Failed to open scoring script {err.filename}.",
            no_personal_data_message="Failed to open scoring script.",
        ) from err

    try:
        ast.parse(contents, str(score_script_path))
    except SyntaxError as err:
        # Report only the file name. Path(...).name also strips Windows "\" separators, which the
        # previous split("/") missed, and the fallback covers a SyntaxError with no filename.
        err.filename = Path(err.filename).name if err.filename else score_script_path.name
        msg = (
            f"Failed to submit deployment {deployment.name} due to syntax errors "
            f"in scoring script {err.filename}.\nError on line {err.lineno}: "
            f"{err.text}\nIf you wish to bypass this validation use --skip-script-validation parameter."
        )

        np_msg = (
            "Failed to submit deployment due to syntax errors in deployment script."
            "\n If you wish to bypass this validation use --skip-script-validation parameter."
        )
        # Batch deployments (including model batch deployments) report against the batch target,
        # consistent with how upload_dependencies classifies them.
        is_batch = isinstance(deployment, (BatchDeployment, ModelBatchDeployment))
        raise ValidationException(
            message=msg,
            target=ErrorTarget.BATCH_DEPLOYMENT if is_batch else ErrorTarget.ONLINE_DEPLOYMENT,
            no_personal_data_message=np_msg,
            error_category=ErrorCategory.USER_ERROR,
            error_type=ValidationErrorType.CANNOT_PARSE,
        ) from err


def convert_v1_dataset_to_v2(output_data_set: DataVersion, file_name: str) -> Dict[str, Any]:
    """Build a serialized v2 ``uri_file`` job output pointing at a v1 dataset's datastore location.

    :param output_data_set: v1 dataset version providing ``datastore_id`` and ``path``.
    :type output_data_set: DataVersion
    :param file_name: Optional file name appended to the dataset path; ignored when empty.
    :type file_name: str
    :return: A dict mapping ``"output_name"`` to the serialized ``UriFileJobOutput``.
    :rtype: Dict[str, Any]
    """
    uri = f"azureml://datastores/{output_data_set.datastore_id}/paths/{output_data_set.path}"
    if file_name:
        uri = f"{uri}/{file_name}"
    serialized_output = UriFileJobOutput(uri=uri).serialize()
    return {"output_name": serialized_output}
